# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

import asyncio
import random
import string
from typing import AsyncGenerator, List, Union
from unittest import mock

import pytest
import torch

from megatron.core import parallel_state
from megatron.core.inference.contexts import StaticInferenceContext
from megatron.core.inference.engines import StaticInferenceEngine
from megatron.core.inference.inference_request import (
    DynamicInferenceRequestRecord,
    InferenceRequest,
    Status,
)
from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
    GPTInferenceWrapper,
)
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
    InferenceWrapperConfig,
)
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.text_generation_controllers.text_generation_controller import (
    TextGenerationController,
)
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import is_fa_min_version
from tests.unit_tests.test_utilities import Utils


class StaticInferenceEngineTestHarness:
    def setup_engine(
        self,
        engine_max_batch_size=None,
        vocab_size=100,
        tensor_model_parallel_size=1,
        pipeline_model_parallel_size=1,
        expert_model_parallel_size=1,
        sequence_parallel=False,
        legacy=False,
        buffer_size_gb=10,
        inference_config_params_dtype=torch.float,
    ):
        Utils.initialize_model_parallel(
            tensor_model_parallel_size=tensor_model_parallel_size,
            pipeline_model_parallel_size=pipeline_model_parallel_size,
        )

        model_parallel_cuda_manual_seed(123)
        self.batch_size = 4
        self.hidden_size = 32
        self.vocab_size = vocab_size
        self.sequence_length = 64
        transformer_config = TransformerConfig(
            num_layers=4,
            hidden_size=self.hidden_size,
            num_attention_heads=4,
            use_cpu_initialization=True,
            inference_rng_tracker=True,
            tensor_model_parallel_size=tensor_model_parallel_size,
            pipeline_model_parallel_size=pipeline_model_parallel_size,
            expert_model_parallel_size=expert_model_parallel_size,
            num_moe_experts=None if expert_model_parallel_size == 1 else expert_model_parallel_size,
            sequence_parallel=sequence_parallel,
            pipeline_dtype=inference_config_params_dtype,
            params_dtype=inference_config_params_dtype,
            add_bias_linear=expert_model_parallel_size == 1,
        )

        gpt_model = GPTModel(
            config=transformer_config,
            transformer_layer_spec=get_gpt_layer_local_spec(),
            vocab_size=self.vocab_size,
            max_sequence_length=self.sequence_length,
            parallel_output=True,
            pre_process=parallel_state.is_pipeline_first_stage(),
            post_process=parallel_state.is_pipeline_last_stage(),
        ).cuda()
        gpt_model.to(inference_config_params_dtype)

        inference_wrapper_config = InferenceWrapperConfig(
            hidden_size=self.hidden_size,
            inference_batch_times_seqlen_threshold=400,
            inference_max_requests=self.batch_size,
            fp32_residual_connection=False,
            params_dtype=inference_config_params_dtype,
            padded_vocab_size=self.vocab_size,
        )

        inference_context = StaticInferenceContext.from_config(inference_wrapper_config)

        inference_wrapped_model = GPTInferenceWrapper(
            gpt_model, inference_wrapper_config, inference_context
        )
        self.mock_tokenizer = mock.Mock()
        # Set required tokenizer attributes before engine creation
        self.mock_tokenizer.vocab_size = self.vocab_size
        self.mock_tokenizer.eod = self.vocab_size - 1
        text_generation_controller = TextGenerationController(
            inference_wrapped_model=inference_wrapped_model, tokenizer=self.mock_tokenizer
        )

        if engine_max_batch_size is not None and engine_max_batch_size > self.batch_size:
            with pytest.warns(UserWarning):
                self.static_engine = StaticInferenceEngine(
                    text_generation_controller=text_generation_controller,
                    max_batch_size=engine_max_batch_size,
                    legacy=legacy,
                    buffer_size_gb=buffer_size_gb,
                )
        else:
            self.static_engine = StaticInferenceEngine(
                text_generation_controller=text_generation_controller,
                max_batch_size=engine_max_batch_size,
                legacy=legacy,
                buffer_size_gb=buffer_size_gb,
            )

    def teardown_method(self, method):
        Utils.destroy_model_parallel()


class TestStaticInferenceEngine(StaticInferenceEngineTestHarness):
    @pytest.mark.parametrize(
        "batch_size,num_trials,empty_prompt",
        [(4, 1, False), (4, 1, True), (4, 3, False), (2, 1, False), (8, 1, False)],
    )
    def test_generate_legacy_static(self, batch_size: int, num_trials: int, empty_prompt: bool):
        self.setup_engine(engine_max_batch_size=batch_size, legacy=True)
        # Generating random length integer prompts
        self.mock_tokenizer.tokenize.return_value = [
            random.randint(0, self.vocab_size - 1) for _ in range(random.randint(5, 10))
        ]
        # Generates some random string
        self.mock_tokenizer.detokenize.return_value = ''.join(
            random.choices(string.ascii_letters, k=random.randint(4, 10))
        )

        for _ in range(num_trials):
            if empty_prompt:
                prompts = ["" for i in range(batch_size)]
            else:
                prompts = ["sample" * (i + 1) for i in range(batch_size)]
            results: List[InferenceRequest] = self.static_engine.generate(
                prompts, sampling_params=SamplingParams(num_tokens_to_generate=10)
            )

            assert len(results) == batch_size
            for result in results:
                assert (
                    result.status == Status.COMPLETED
                ), f"Status should be completed but its {result.status}"
                assert result.generated_length > 0, f"Generated length should be greater than zero"
                assert result.generated_text is not None, f'Generated text should not be None'

    @pytest.mark.skipif(
        not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching"
    )
    @pytest.mark.parametrize(
        "batch_size,num_trials,empty_prompt",
        [(4, 1, False), (4, 1, True), (4, 3, False), (2, 1, False), (8, 1, False)],
    )
    def test_generate_dynamic(self, batch_size: int, num_trials: int, empty_prompt: bool):
        self.setup_engine(
            engine_max_batch_size=batch_size,
            legacy=False,
            buffer_size_gb=2,
            inference_config_params_dtype=torch.bfloat16,
        )
        # Generating random length integer prompts
        self.mock_tokenizer.tokenize.return_value = [
            random.randint(0, self.vocab_size - 1) for _ in range(random.randint(5, 10))
        ]
        # Generates some random string
        self.mock_tokenizer.detokenize.return_value = ''.join(
            random.choices(string.ascii_letters, k=random.randint(4, 10))
        )
        assert hasattr(self.static_engine, 'dynamic_engine'), "Dynamic engine not initialized"
        assert (
            self.static_engine.legacy is False
        ), "Using legacy static engine when it should be using dynamic engine"

        for _ in range(num_trials):
            if empty_prompt:
                prompts = ["" for i in range(batch_size)]
            else:
                prompts = ["sample" * (i + 1) for i in range(batch_size)]
            results: List[Union[InferenceRequest, DynamicInferenceRequestRecord]] = (
                self.static_engine.generate(
                    prompts, sampling_params=SamplingParams(num_tokens_to_generate=10)
                )
            )

            assert len(results) == batch_size
            for result in results:
                if isinstance(result, DynamicInferenceRequestRecord):
                    result = result.merge()
                assert isinstance(result, InferenceRequest), (
                    "expected <InferenceRequest>; found <%s>." % type(result).__name__
                )
                assert (
                    result.status == Status.COMPLETED
                ), f"Status should be completed but its {result.status}"
                assert result.generated_length > 0, f"Generated length should be greater than zero"
                assert result.generated_text is not None, f'Generated text should not be None'

    @pytest.mark.asyncio
    async def test_streaming(self):
        self.setup_engine(legacy=True)

        async def collect_stream(stream_generator, num_tokens_to_generate):
            prev_log_probs = None
            prev_text = ""
            prev_idx = 0
            prev_length = 0
            num_output_tokens = 0
            async for output in stream_generator:
                num_output_tokens += 1
                assert isinstance(
                    output, InferenceRequest
                ), f"Expected InferenceRequest, got {type(output)}"
                assert output.generated_log_probs is not None, f"Expected log probs tensor"
                assert (
                    output.generated_tokens.shape[0] == output.generated_length
                ), f"Expected log probs length to match # generated tokens"
                assert (
                    len(output.generated_log_probs) == output.generated_length
                ), f"Expected log probs length to match # generated tokens"
                assert output.generated_length > prev_length, f"Expected generated length to grow"
                assert (
                    output.generated_text[:prev_idx] == prev_text
                ), f"Expected generated text to match previous text"
                assert (
                    prev_log_probs is None or prev_log_probs == output.generated_log_probs[:-1]
                ), f"Expected previous log probs to match new log probs"
                prev_length = output.generated_length
                prev_text = output.generated_text
                prev_idx = len(output.generated_text)
                prev_log_probs = output.generated_log_probs

            assert (
                num_output_tokens == num_tokens_to_generate
            ), f"Should have streamed {num_tokens_to_generate} tokens but actually streamed {num_output_tokens}"
            assert (
                len(output.generated_tokens) == num_tokens_to_generate
            ), f"Should have included {num_tokens_to_generate} tokens but actually returned {len(output.generated_tokens)}"
            assert (
                len(output.generated_log_probs) == num_tokens_to_generate
            ), f"Should have included {num_tokens_to_generate} log probs but actually returned {len(output.generated_log_probs)}"

            return output

        self.mock_tokenizer.bos = self.vocab_size - 2
        # Generating random length integer prompts
        self.mock_tokenizer.tokenize.return_value = [
            random.randint(0, self.vocab_size - 1) for _ in range(random.randint(5, 10))
        ]
        # Generates some random string
        self.mock_tokenizer.detokenize.return_value = ''.join(
            random.choices(string.ascii_letters, k=random.randint(4, 10))
        )

        prompts = ["" for i in range(self.batch_size)]

        num_tokens_to_generate = 10
        sampling_params = SamplingParams(
            num_tokens_to_generate=num_tokens_to_generate, return_log_probs=True
        )
        request_ids: List[int] = [
            self.static_engine.add_request(
                prompt, add_BOS=True, sampling_params=sampling_params, streaming=True
            )
            for prompt in prompts
        ]
        stream_generators: List[AsyncGenerator[InferenceRequest, None]] = [
            self.static_engine.get_stream_generator(request_id) for request_id in request_ids
        ]
        assert all(stream_generator is not None for stream_generator in stream_generators)

        tasks = [
            asyncio.create_task(collect_stream(stream_generator, num_tokens_to_generate))
            for stream_generator in stream_generators
        ]

        await self.static_engine.run_engine_async()
        final_streamed_tokens: List[InferenceRequest] = await asyncio.gather(*tasks)
        results: List[InferenceRequest] = [
            self.static_engine.scheduler.completed_request_pool[request_id]
            for request_id in request_ids
        ]
        assert len(final_streamed_tokens) == len(results)
        for result, final_streamed_token in zip(results, final_streamed_tokens):
            assert torch.equal(
                result.generated_tokens.cpu(), final_streamed_token.generated_tokens.cpu()
            ), (
                f"result.generated_tokens={result.generated_tokens.cpu()},"
                f"final_streamed_token.generated_tokens={final_streamed_token.generated_tokens}"
            )
            assert result.generated_log_probs == final_streamed_token.generated_log_probs, (
                f"result.generated_log_probs={result.generated_log_probs}, "
                f"final_streamed_token.generated_log_probs={final_streamed_token.generated_log_probs}"
            )

    @pytest.mark.parametrize("sequence_parallel", [False, True])
    @pytest.mark.parametrize("ep_size", [1, 2])
    @pytest.mark.parametrize("pp_size", [1, 2])
    @pytest.mark.parametrize("tp_size", [1, 2])
    def test_parallel_inference(self, tp_size, pp_size, ep_size, sequence_parallel):
        if tp_size == 1 and pp_size == 1 and ep_size == 1:
            pytest.skip(reason="Test requires tp_size > 1 or pp_size > 1 or ep_size > 1")
        elif not torch.distributed.is_initialized():
            pytest.skip("Distributed not initialized")
        world_size = torch.distributed.get_world_size()
        min_world_size = tp_size * pp_size * ep_size
        if world_size < min_world_size:
            pytest.skip(f"Test requires at least {min_world_size} GPUs")
        elif tp_size == 1 and sequence_parallel:
            pytest.skip(reason="Sequence parallelism requires tp_size > 1")
        elif tp_size > 1 and ep_size > 1 and not sequence_parallel:
            pytest.skip(reason="Sequence parallelism must be used with tp_size > 1 and ep_size > 1")

        batch_size = 8

        self.setup_engine(
            engine_max_batch_size=batch_size,
            tensor_model_parallel_size=tp_size,
            pipeline_model_parallel_size=pp_size,
            expert_model_parallel_size=ep_size,
            sequence_parallel=sequence_parallel,
            legacy=True,
        )
        self.mock_tokenizer.eod = -1

        random.seed(42)

        # Generating random length integer prompts, ensuring sequence length is divisible by TP size
        self.mock_tokenizer.tokenize.return_value = [
            random.randint(0, self.vocab_size - 1) for _ in range(32)
        ]
        # Generates some random string
        self.mock_tokenizer.detokenize.return_value = ''.join(
            random.choices(string.ascii_letters, k=random.randint(4, 10))
        )

        prompts = ["sample" * (i + 1) for i in range(batch_size)]

        if sequence_parallel and (ep_size == 1 or tp_size == 1):
            with pytest.raises(NotImplementedError):
                results: List[InferenceRequest] = self.static_engine.generate(
                    prompts, sampling_params=SamplingParams(num_tokens_to_generate=10)
                )
            return
        else:
            results: List[InferenceRequest] = self.static_engine.generate(
                prompts, sampling_params=SamplingParams(num_tokens_to_generate=10)
            )

        assert len(results) == batch_size
        for result in results:
            assert (
                result.status == Status.COMPLETED
            ), f"Status should be completed but its {result.status}"
            assert result.generated_length > 0, f"Generated length should be greater than zero"
            assert result.generated_text is not None, f'Generated text should not be None'
